#!/usr/bin/env python  
#-*- coding:utf-8 _*-  
""" 
@author:hello_life 
@license: Apache Licence 
@file: main_2.py 
@time: 2022/05/09
@software: PyCharm 
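description: Entry point that trains the ALBERT-based NER model (Albert_NER) on CLUE_Dataset data via train_loop.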
"""
import torch
from torch.utils.data import DataLoader

from model.albert import Albert_NER
from parameters.bert_ner_config import Config
from utils.train_utils import train_loop
from utils.CLUE_dataset import CLUE_Dataset,collate_fn

def main():
    """Build the data loaders and the ALBERT NER model, then run training.

    All settings come from ``Config()``. This function reads
    ``config.device`` and ``config.batch_size`` directly. The full ``config``
    object is also passed to ``CLUE_Dataset``, ``Albert_NER`` and
    ``train_loop``.
    """
    config = Config()
    print(f"Training on {config.device}")

    # Load data.
    # Reshuffle training batches every epoch so the model does not see the
    # examples in the same fixed order on each pass.
    train_data = CLUE_Dataset(config)
    train_data_loader = DataLoader(train_data, batch_size=config.batch_size,
                                   collate_fn=collate_fn, shuffle=True)

    # NOTE(review): the evaluation dataset is constructed with exactly the
    # same arguments as the training dataset. Unless CLUE_Dataset selects its
    # split some other way, train_loop ends up evaluating on training data.
    # TODO: request the dev/test split here once CLUE_Dataset's
    # split-selection API is confirmed.
    test_data = CLUE_Dataset(config)
    # Keep evaluation batches in a fixed, deterministic order.
    test_data_loader = DataLoader(test_data, batch_size=config.batch_size,
                                  collate_fn=collate_fn, shuffle=False)

    # Load the model and move it to the configured device.
    model = Albert_NER(config).to(config.device)

    train_loop(model, train_data_loader, test_data_loader, config)


if __name__ == '__main__':
    main()
